import os
from collections.abc import Iterator
from functools import partial
from typing import Any, Literal, get_args

import httpx
import pytest
from typing_extensions import TypedDict

from pydantic_ai.models import KnownModelName
from pydantic_ai.providers.gateway import ModelProvider as GatewayModelProvider

from ..conftest import try_import

with try_import() as imports_successful:
    from pydantic_ai.models.anthropic import AnthropicModelName
    from pydantic_ai.models.bedrock import BedrockModelName
    from pydantic_ai.models.cohere import CohereModelName
    from pydantic_ai.models.google import GoogleModelName
    from pydantic_ai.models.groq import GroqModelName
    from pydantic_ai.models.huggingface import HuggingFaceModelName
    from pydantic_ai.models.mistral import MistralModelName
    from pydantic_ai.models.openai import OpenAIModelName
    from pydantic_ai.providers.deepseek import DeepSeekModelName
    from pydantic_ai.providers.grok import GrokModelName
    from pydantic_ai.providers.moonshotai import MoonshotAIModelName

# Module-wide marks: skip every test here unless all optional model packages imported cleanly,
# and route each test's HTTP traffic through VCR (configured by `vcr_config`).
pytestmark = [
    pytest.mark.skipif(
        not imports_successful(),
        reason='some model package was not installed',
    ),
    pytest.mark.vcr,
]


def modify_response(response: dict[str, Any], filter_headers: list[str]) -> dict[str, Any]:  # pragma: lax no cover
    """Remove unwanted headers from a recorded response, mutating it in place.

    Each header name is lower-cased before being looked up in `filter_headers`, so the entries of
    `filter_headers` must themselves be lower-case to ever match. The same `response` object is
    returned so this can be used directly as a `before_record_response` hook.
    """
    headers = response['headers']
    # Take a snapshot of the names: deleting from a dict while iterating it directly would raise.
    for name in list(headers):
        assert isinstance(name, str)
        if name.lower() not in filter_headers:
            continue
        del headers[name]
    return response


@pytest.fixture(scope='module')
def vcr_config():  # pragma: lax no cover
    """VCR configuration for this module's HTTP calls.

    In CI, or when no `CEREBRAS_API_KEY` is set, only existing cassettes are used (`record_mode='none'`).
    Otherwise cassettes are re-recorded (`record_mode='rewrite'`) so the fetched model lists are refreshed.
    """
    if os.getenv('CI') or not os.getenv('CEREBRAS_API_KEY'):
        return {'record_mode': 'none'}

    return {
        'record_mode': 'rewrite',
        # `get_cerebras_model_names` sends the real key as `Authorization: Bearer ...` when recording;
        # strip it from the stored request so the key never lands in the cassette file.
        'filter_headers': ['accept-encoding', 'authorization'],
        # Drop volatile response headers before the interaction is written to the cassette.
        'before_record_response': partial(modify_response, filter_headers=['cache-control', 'connection']),
    }


# Maps each `provider:` prefix used in `KnownModelName` to the type alias enumerating that provider's
# model names (both Google providers share `GoogleModelName`).
#
# The aliases are bound inside the `try_import()` block at the top of this module. When an optional
# package is missing, the names after the failing import are never bound; presumably `try_import`
# swallows the ImportError, which is what lets the `skipif` mark take effect. Referencing them
# unconditionally here would raise `NameError` during collection instead of skipping, so the mapping is
# only built when every import succeeded. Otherwise it is left empty and the module's tests are skipped.
_PROVIDER_TO_MODEL_NAMES: dict[str, Any] = (
    {
        'anthropic': AnthropicModelName,
        'bedrock': BedrockModelName,
        'cohere': CohereModelName,
        'deepseek': DeepSeekModelName,
        'google-gla': GoogleModelName,
        'google-vertex': GoogleModelName,
        'grok': GrokModelName,
        'groq': GroqModelName,
        'huggingface': HuggingFaceModelName,
        'mistral': MistralModelName,
        'moonshotai': MoonshotAIModelName,
        'openai': OpenAIModelName,
    }
    if imports_successful()
    else {}
)


def test_known_model_names():  # pragma: lax no cover
    """Check that `KnownModelName` lists exactly the names derivable from the provider aliases.

    The expected names are assembled from:
    - every `provider:model` pair obtained from `_PROVIDER_TO_MODEL_NAMES`,
    - `gateway/provider:model` for each provider in the gateway `ModelProvider` type,
    - the Heroku and Cerebras model lists fetched over HTTP (through VCR, per this module's marks),
    - the special `test` model.
    """

    # Coverage seems to be misbehaving..?
    def get_model_names(model_name_type: Any) -> Iterator[str]:
        # Recursively flatten `Literal[...]` / `Union[...]` aliases into their string members.
        # Non-literal members such as a bare `str` have no type args and contribute nothing.
        for arg in get_args(model_name_type):
            if isinstance(arg, str):
                yield arg
            else:
                yield from get_model_names(arg)

    all_generated_names = [
        f'{provider}:{n}'
        for provider, model_names in _PROVIDER_TO_MODEL_NAMES.items()
        for n in get_model_names(model_names)
    ]

    cerebras_names = get_cerebras_model_names()
    heroku_names = get_heroku_model_names()
    gateway_names = [
        f'gateway/{provider}:{model_name}'
        # `get_args` rather than the `__args__` dunder, matching how type args are read elsewhere here.
        for provider in get_args(GatewayModelProvider)
        for model_name in get_model_names(_PROVIDER_TO_MODEL_NAMES[provider])
    ]

    extra_names = ['test']

    generated_names = sorted(all_generated_names + gateway_names + heroku_names + cerebras_names + extra_names)

    known_model_names = sorted(get_args(KnownModelName.__value__))

    # Exact list equality (duplicates included) is what is checked. The set differences only make a
    # failure point straight at the offending names instead of a long positional diff.
    missing_from_known = sorted(set(generated_names) - set(known_model_names))
    not_generated = sorted(set(known_model_names) - set(generated_names))
    assert generated_names == known_model_names, (
        f'Missing from KnownModelName: {missing_from_known}; in KnownModelName but not generated: {not_generated}'
    )


# One entry of the JSON list returned by Heroku's `available-models` endpoint, as read by
# `get_heroku_model_names` (which uses `model_id` and `type`).
HerokuModel = TypedDict(
    'HerokuModel',
    {
        'model_id': str,
        'regions': list[str],
        'type': list[str],
    },
)


def get_heroku_model_names():
    """Return the sorted `heroku:<model_id>` names of every text-to-text model Heroku lists.

    Skips the calling test if the endpoint does not answer with HTTP 200.
    """
    response = httpx.get('https://us.inference.heroku.com/available-models')

    status = response.status_code
    if status != 200:
        pytest.skip(f'Heroku AI returned status code {status}')  # pragma: lax no cover

    catalogue: list[HerokuModel] = response.json()
    # Only models whose `type` list includes text-to-text are expected in `KnownModelName`.
    return sorted(f'heroku:{entry["model_id"]}' for entry in catalogue if 'text-to-text' in entry['type'])


# One entry of the `data` list returned by Cerebras's `/v1/models` endpoint, as read by
# `get_cerebras_model_names` (which uses `id`).
CerebrasModel = TypedDict(
    'CerebrasModel',
    {
        'created': int,
        'id': str,
        'object': Literal['model'],
        'owned_by': Literal['Cerebras'],
    },
)


def get_cerebras_model_names():  # pragma: lax no cover
    """Return the sorted `cerebras:<id>` names of every model Cerebras lists.

    Authenticates with `CEREBRAS_API_KEY` when it is set and with the placeholder `testing` otherwise.
    Presumably the placeholder is enough because requests are then served from the recorded cassette.
    Skips the calling test if the endpoint does not answer with HTTP 200.
    """
    token = os.getenv('CEREBRAS_API_KEY', 'testing')
    request_headers = {
        'Authorization': f'Bearer {token}',
        'Accept': 'application/json',
        # Ask for an uncompressed body.
        'Accept-Encoding': 'identity',
    }
    response = httpx.get('https://api.cerebras.ai/v1/models', headers=request_headers)

    status = response.status_code
    if status != 200:
        pytest.skip(f'Cerebras returned status code {status}')  # pragma: lax no cover

    payload = response.json()
    cerebras_models: list[CerebrasModel] = payload['data']
    names = [f'cerebras:{model["id"]}' for model in cerebras_models]
    names.sort()
    return names
